# Conflict-prediction replication driver.
#
# Command-line arguments are read positionally from the FULL commandArgs()
# vector (not trailingOnly):
#   args[6]  country code, one of {indo, colo}
#   args[7]  table / figure to build (e.g. table1, table2, table4, ...)
#   args[8]  dependent variable, one of {any, high, spike} (unused for table2)
# NOTE(review): positions 6-8 assume a plain `Rscript <file> <country> <table>
# <dv>` launch; extra Rscript/R options shift them — confirm against the run
# scripts that invoke this file.
#
# The number of parallel workers defaults to 10 and can be overridden with the
# N_WORKERS environment variable.

# Start from an empty global environment
rm(list = ls())

# All paths below are relative to the package root, so this script must be
# launched from the directory that contains conflict_prediction_replication_pkg/
setwd("conflict_prediction_replication_pkg")

# Load packages, helpers, and model parameters
source("config_code/directories.R")
source("config_code/package_requirements.R")
source("config_code/helper_functions.R")
source("config_code/model_parameters.R")

args <- commandArgs()

# Which table/figure to build. NOTE: this masks base::table() for the rest of
# the script; the name is kept because every section below reads it.
table <- args[7]
if (is.na(table)) {
  # Without this check every section is skipped and the script exits silently
  stop("No table argument supplied (expected commandArgs()[7]); usage: ",
       "Rscript <script> <country> <table> [<dv>]", call. = FALSE)
}

# Register a doParallel foreach backend (presumably used by the estimation
# scripts sourced later). Validated first so a bad value fails before any
# worker processes are started.
# NOTE(review): no stopCluster(cl) is visible in this script; workers are
# presumably torn down when the R process exits — confirm.
n_workers <- as.integer(Sys.getenv("N_WORKERS", unset = "10"))
if (is.na(n_workers) || n_workers < 1L) {
  stop("N_WORKERS must be a positive integer", call. = FALSE)
}
cl <- makeCluster(n_workers)
registerDoParallel(cl)

## Run Prediction Algorithms for Table 1 and 3 (and the Figure 1-2 variants)
# The estimation scripts are sourced into the global environment. They
# presumably read the globals set here (country, v, rhs.group, rhs, algo), so
# those names must not change.
if (table %in% c("table1", "table3_pt1", "table3_pt2", "table3_pt3",
                 "table3_pt4", "figures_1_2_annual", "figures_1_2_slow",
                 "figures_1_2_fixed")) {
  for (country in args[6]) { # {indo, colo}
    for (v in args[8]) { # {any, high, spike}
      source(file.path("config_code", paste0(country, "_data_setup.R")))
      source(file.path("config_code", paste0(country, "_dependent_vars.R")))
      source(file.path("config_code", paste0(country, "_predictor_vars.R")))

      # `table` names an object holding one RHS variable set per group; it must
      # exist after the config scripts above have been sourced
      rhs_groups <- get(table)
      if (length(rhs_groups) == 0L || is.null(names(rhs_groups))) {
        # An unnamed/empty spec would otherwise skip all estimation silently
        stop(sprintf("'%s' must be a non-empty named list of RHS groups",
                     table), call. = FALSE)
      }

      for (rhs.group in names(rhs_groups)) {
        rhs <- rhs_groups[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste(v, rhs.group, algo, sep = "---"))
          source(file.path("estimation_code", "main",
                           paste0("conflict_", algo, ".R")))
        }
      }
    }
  }
}

## Run Benchmark Models for Table 2
# All three dependent variables are run in one pass, so the <dv> command-line
# argument (args[8]) is deliberately not used here. Loop-variable names are
# kept because the sourced benchmark scripts presumably read them as globals.
if (table %in% "table2") {
  for (country in args[6]) { # {indo, colo}
    for (v in c("any", "high", "spike")) {
      source(file.path("config_code", paste0(country, "_data_setup.R")))
      source(file.path("config_code", paste0(country, "_dependent_vars.R")))
      source(file.path("config_code", paste0(country, "_predictor_vars.R")))

      # Benchmarks always use the full predictor set; full_vars must be defined
      # by the config scripts sourced above
      rhs <- full_vars
      for (algo in c("lagdv", "ols", "ols_fe", "ols_aggfe")) {
        print(paste(v, algo, sep = "---"))
        source(file.path("estimation_code", "benchmarks",
                         paste0("conflict_", algo, ".R")))
      }
    }
  }
}

# Run cross sectional predictions (Table 4 / Figure 2)
# Same loop structure as the main predictions, but split_cross_section.R is
# sourced after the country config and the *_cross.R estimation scripts are
# used. Loop-variable names are kept because the sourced scripts presumably
# read them as globals.
if (table %in% c("table4", "figure_2")) {
  for (country in args[6]) { # {indo, colo}
    for (v in args[8]) { # {any, high, spike}
      source(file.path("config_code", paste0(country, "_data_setup.R")))
      source(file.path("config_code", paste0(country, "_dependent_vars.R")))
      source(file.path("config_code", paste0(country, "_predictor_vars.R")))
      source("config_code/split_cross_section.R")

      # `table` names an object holding one RHS variable set per group; it must
      # exist after the config scripts above have been sourced
      rhs_groups <- get(table)
      if (length(rhs_groups) == 0L || is.null(names(rhs_groups))) {
        # An unnamed/empty spec would otherwise skip all estimation silently
        stop(sprintf("'%s' must be a non-empty named list of RHS groups",
                     table), call. = FALSE)
      }

      for (rhs.group in names(rhs_groups)) {
        rhs <- rhs_groups[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste(v, rhs.group, algo, sep = "---"))
          source(file.path("estimation_code", "cross",
                           paste0("conflict_", algo, "_cross.R")))
        }
      }
    }
  }
}

